// Configuration schema for the boosted trees learner.
syntax = "proto3";

// Requests arena-allocation support in the generated C++ classes.
option cc_enable_arenas = true;

// Every message in this file is declared inside this proto package.
package tensorflow.boosted_trees.learner;

// Tree regularization config.
//
// All fields are plain proto3 scalars, so an unset field and an explicit 0 are
// indistinguishable to readers of this message.
message TreeRegularizationConfig {
  // Classic L1 regularization strength.
  float l1 = 1;
  // Classic L2 regularization strength.
  float l2 = 2;

  // Tree complexity penalty: penalizes overall model complexity, effectively
  // limiting how deep the tree can grow in regions with small gain.
  float tree_complexity = 3;
}

// Tree constraints config: limits placed on the trees that are built.
message TreeConstraintsConfig {
  // Maximum depth of the trees. The default value is 6 if not specified.
  // Note: an unset proto3 uint32 reads as 0, so the fallback to 6 has to be
  // applied by the code that consumes this config, not by this schema.
  uint32 max_tree_depth = 1;

  // Minimum hessian weight per node.
  float min_node_weight = 2;

  // Maximum number of unique features used in the tree. Zero (which is also
  // what an unset field reads as) means there is no limit.
  // NOTE(review): the type is signed, but no meaning is defined here for
  // negative values -- confirm how readers treat them.
  int64 max_number_of_unique_feature_columns = 3;
}

// LearningRateConfig describes all supported learning rate tuners.
//
// The tuners share a oneof, so at most one of them is set at any time; setting
// one clears whichever was set before.
message LearningRateConfig {
  oneof tuner {
    // Fixed learning rate.
    LearningRateFixedConfig fixed = 1;
    // Learning rate combined with dropout of trees.
    LearningRateDropoutDrivenConfig dropout = 2;
    // Learning rate tuned over a number of candidate values up to a maximum.
    LearningRateLineSearchConfig line_search = 3;
  }
}

// Config for a fixed learning rate.
message LearningRateFixedConfig {
  // The learning rate to use. As a plain proto3 scalar, an unset value reads
  // as 0.
  float learning_rate = 1;
}

// Config for a tuned learning rate: several candidate values below
// max_learning_rate are considered (see num_steps).
message LearningRateLineSearchConfig {
  // Max learning rate. Must be strictly positive.
  float max_learning_rate = 1;

  // Number of learning rate values to consider in the range
  // [0, max_learning_rate).
  int32 num_steps = 2;
}

// Config for averaging ensembles at prediction time.
//
// When we have a sequence of trees 1, 2, 3 ... n, these essentially represent
// weight updates in functional space, and thus we can use averaging of weight
// updates to achieve better performance. For example, the final ensemble can
// be the average of: the ensemble of tree 1; the ensemble of tree 1 and
// tree 2; and so on, up to the ensemble of all trees.
// Note that this averaging applies ONLY DURING PREDICTION. The training
// stays the same.
message AveragingConfig {
  // At most one of the two selection modes below can be set.
  oneof config {
    // NOTE(review): presumably the number of most recent trees whose ensembles
    // are averaged -- confirm. The field is declared as float rather than an
    // integer type; that cannot be changed without breaking the wire format.
    float average_last_n_trees = 1;
    // Between 0 and 1 (a fraction, despite the name). If set to 1.0, we are
    // averaging the ensemble of tree 1, the ensemble of tree 1 and tree 2,
    // etc., up to the ensemble of all trees. If set to 0.5, the last half of
    // the trees are averaged, etc.
    float average_last_percent_trees = 2;
  }
}

// Config for a learning rate used together with dropout of existing trees.
message LearningRateDropoutDrivenConfig {
  // Probability of dropping each tree in the ensemble built so far.
  float dropout_probability = 1;

  // When trees are built after dropout happens, they don't "advance" to the
  // optimal solution, they just rearrange the path. However, you can still
  // choose to skip dropout periodically, to allow a new tree that "advances"
  // to be added.
  // For example, if running for 200 steps with probability of dropout 1/100,
  // you would expect the dropout to start happening for sure for all iterations
  // after 100. However, you can add probability_of_skipping_dropout of 0.1;
  // this way iterations 100-200 will include approx 90 iterations of dropout
  // and 10 iterations of normal steps. Set it to 0 if you want to just keep
  // building the refinement trees after dropout kicks in.
  float probability_of_skipping_dropout = 2;

  // Learning rate. Between 0 and 1.
  float learning_rate = 3;
}

// Top-level learner configuration. It combines the regularization,
// constraints, learning rate and averaging configs with the enum-valued
// choices declared below.
//
// Field numbers are not declared in numeric order (6 appears after 8 and 9);
// only the numbers matter on the wire, so they must never be renumbered.
// NOTE(review): field number 7 is not used by any field in this message. If it
// belonged to a field that was removed, it should be declared `reserved 7;` --
// confirm against the file history.
message LearnerConfig {
  // How trees are pruned. See the pruning_mode field for the default that
  // applies when the value is left unspecified.
  enum PruningMode {
    // Zero value: what an unset pruning_mode field reads as.
    PRUNING_MODE_UNSPECIFIED = 0;
    PRE_PRUNE = 1;
    POST_PRUNE = 2;
  }

  // How trees are grown. See the growing_mode field for the default that
  // applies when the value is left unspecified.
  enum GrowingMode {
    // Zero value: what an unset growing_mode field reads as.
    GROWING_MODE_UNSPECIFIED = 0;
    WHOLE_TREE = 1;
    LAYER_BY_LAYER = 2;
  }

  // Strategy for multi-class problems. See the multi_class_strategy field for
  // the defaults that apply when the value is left unspecified.
  enum MultiClassStrategy {
    // Zero value: what an unset multi_class_strategy field reads as.
    MULTI_CLASS_STRATEGY_UNSPECIFIED = 0;
    TREE_PER_CLASS = 1;
    FULL_HESSIAN = 2;
    DIAGONAL_HESSIAN = 3;
  }

  // Kind of weak learner to build.
  // Unlike the other enums in this message there is no *_UNSPECIFIED value:
  // the zero value is NORMAL_DECISION_TREE, so an unset weak_learner_type
  // field reads as NORMAL_DECISION_TREE. This cannot be changed without
  // altering the meaning of already-serialized configs.
  enum WeakLearnerType {
    NORMAL_DECISION_TREE = 0;
    OBLIVIOUS_DECISION_TREE = 1;
  }

  // Number of classes.
  uint32 num_classes = 1;

  // Fraction of features to consider in each tree sampled randomly
  // from all available features. At most one of the two variants can be set.
  oneof feature_fraction {
    // NOTE(review): presumably the fraction is sampled once per tree -- confirm.
    float feature_fraction_per_tree = 2;
    // NOTE(review): presumably the fraction is re-sampled at every tree
    // level -- confirm.
    float feature_fraction_per_level = 3;
  }

  // Regularization.
  TreeRegularizationConfig regularization = 4;

  // Constraints.
  TreeConstraintsConfig constraints = 5;

  // Pruning. POST_PRUNE is the default pruning mode.
  PruningMode pruning_mode = 8;

  // Growing Mode. LAYER_BY_LAYER is the default growing mode.
  GrowingMode growing_mode = 9;

  // Learning rate. By default we use fixed learning rate of 0.1.
  LearningRateConfig learning_rate_tuner = 6;

  // Multi-class strategy. By default we use TREE_PER_CLASS for binary
  // classification and linear regression. For other cases, we use
  // DIAGONAL_HESSIAN as the default.
  MultiClassStrategy multi_class_strategy = 10;

  // If you want to average the ensembles (for regularization), provide the
  // config below.
  AveragingConfig averaging_config = 11;

  // By default we use NORMAL_DECISION_TREE as weak learner.
  WeakLearnerType weak_learner_type = 12;
}
